// -*- mode:c++ -*-

// Copyright (c) 2022 PLCT Lab
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met: redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer;
// redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution;
// neither the name of the copyright holders nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


let {{
def setVlenb():
    """Return the C++ statement that declares ``vlenb`` as ``vlen >> 3``.

    The emitted declaration carries ``[[maybe_unused]]`` and ends with a
    newline so it can be dropped directly into a generated function body.
    """
    vlenb_decl = "[[maybe_unused]] uint32_t vlenb = vlen >> 3;"
    return vlenb_decl + "\n"

def declareVMemTemplate(class_name):
    """Return C++ explicit template instantiations of ``class_name``.

    One ``template class <class_name><T>;`` line is produced for each of
    uint8_t, uint16_t, uint32_t and uint64_t, in that order. The text
    starts with a newline, every line is indented by four spaces, and the
    result ends with a trailing four-space indent.
    """
    elem_types = ('uint8_t', 'uint16_t', 'uint32_t', 'uint64_t')
    instantiations = ''.join(
        f'    template class {class_name}<{elem_type}>;\n'
        for elem_type in elem_types)
    return '\n' + instantiations + '    '

def getFaultCode():
    """Return the C++ snippet that post-processes a memory access fault.

    The generated code only acts when ``fault`` is set and
    ``getFaultVAddr`` yields an address. It then asserts the address is
    not below ``EA`` and stores in ``faultIdx`` the byte offset from
    ``EA`` divided by ``width_EEW(machInst.width) / 8``. Unless both
    ``microIdx`` and ``faultIdx`` are zero, the fault is replaced by
    ``NoFault`` and ``trimVl`` is set.
    """
    fault_snippet = '''
    Addr fault_addr;
    if (fault != NoFault && getFaultVAddr(fault, fault_addr)) {
        assert(fault_addr >= EA);
        faultIdx = (fault_addr - EA) / (width_EEW(machInst.width) / 8);
        if (microIdx != 0 || faultIdx != 0) {
            fault = NoFault;
            trimVl = true;
        }
    }
    '''
    return fault_snippet

def getTailMaskPolicyCode():
    """Return the C++ snippet that gives ``tmp_d0`` its initial contents.

    In the generated code, ``tmp_d0`` is copied from register operand 1
    (read into ``old_vd``) when ``vtype8.vta`` is clear, or when both
    ``vm`` and ``vtype8.vma`` are clear. In every other case ``tmp_d0``
    is filled through ``set(0xff)``.
    """
    policy_snippet = '''
    if (!machInst.vtype8.vta || (!machInst.vm && !machInst.vtype8.vma)) {
        RiscvISA::vreg_t old_vd;
        xc->getRegOperand(this, 1, &old_vd);
        tmp_d0 = old_vd;
    } else {
        tmp_d0.set(0xff);
    }
    '''
    return policy_snippet

def VMemBase(name, Name, ea_code, memacc_code, mem_flags,
                   inst_flags, base_class, postacc_code='',
                   declare_template_base=VMemMacroDeclare,
                   decode_template=VMemBaseDecodeBlock, exec_template_base='',
                   microop_template_base='',
                   # If it's a macroop, the corresponding microops will be
                   # generated.
                   is_macroop=True):
    """Generate the output strings for one vector memory instruction.

    Args:
        name, Name: instruction names handed to InstObjParams; the microop
            uses name + '_micro' and Name + 'Micro'.
        ea_code, memacc_code, postacc_code: code snippets made available
            to the templates under the keys of the same names.
        mem_flags: flag name(s); each becomes '(Request::<flag>)' and the
            OR of all of them is assigned to memAccessFlags in the microop
            constructor. Only used when microops are generated.
        inst_flags: flag(s) forwarded to InstObjParams.
        base_class: base class name of the top-level instruction.
        declare_template_base: template producing the top-level
            declaration.
        decode_template: template producing the decode block.
        exec_template_base: template name prefix; the top-level
            constructor template is looked up as <prefix>Constructor.
        microop_template_base: prefix for the <prefix>Micro* templates and
            the <prefix>MicroInst base class; falls back to
            exec_template_base when empty.
        is_macroop: when false, only the top-level instruction is emitted.

    Returns:
        The tuple (header_output, decoder_output, decode_block,
        exec_output). exec_output is '' when is_macroop is false.
    """
    # Accept either a single flag or a list of flags.
    mem_flags = makeList(mem_flags)
    inst_flags = makeList(inst_flags)

    # Snippets shared by the top-level instruction and its microop. Keys are
    # inserted in the same order as before so the snippet ordering seen by
    # InstObjParams does not change.
    shared_snippets = dict()
    shared_snippets['ea_code'] = ea_code
    shared_snippets['memacc_code'] = memacc_code
    shared_snippets['postacc_code'] = postacc_code

    macro_snippets = dict(shared_snippets)
    macro_snippets['declare_vmem_template'] = declareVMemTemplate(Name)
    iop = InstObjParams(name, Name, base_class, macro_snippets, inst_flags)

    # Templates are resolved by name from the prefix.
    constructTemplate = eval(exec_template_base + 'Constructor')
    if not microop_template_base:
        microop_template_base = exec_template_base

    header_output = declare_template_base.subst(iop)
    decoder_output = constructTemplate.subst(iop)
    decode_block = decode_template.subst(iop)
    if not is_macroop:
        # Nothing executable is generated without microops.
        return (header_output, decoder_output, decode_block, '')

    micro_class_name = microop_template_base + 'MicroInst'

    # The fault handling snippet is only emitted for this op class.
    fault_only_first = 'FaultOnlyFirst' in iop.op_class

    micro_snippets = dict(shared_snippets)
    micro_snippets['set_vlenb'] = setVlenb()
    micro_snippets['declare_vmem_template'] = \
        declareVMemTemplate(Name + 'Micro')
    micro_snippets['fault_code'] = \
        getFaultCode() if fault_only_first else ''
    micro_snippets['tail_mask_policy_code'] = getTailMaskPolicyCode()
    microiop = InstObjParams(name + '_micro', Name + 'Micro',
                             micro_class_name, micro_snippets, inst_flags)

    if mem_flags:
        request_flags = '|'.join(
            '(Request::%s)' % flag for flag in mem_flags)
        microiop.constructor += \
            '\n\tmemAccessFlags = ' + request_flags + ';'

    # Resolve every microop template before substituting any of them.
    micro_prefix = microop_template_base + 'Micro'
    microDeclTemplate = eval(micro_prefix + 'Declare')
    microConsTemplate = eval(micro_prefix + 'Constructor')
    microExecTemplate = eval(micro_prefix + 'Execute')
    microInitTemplate = eval(micro_prefix + 'InitiateAcc')
    microCompTemplate = eval(micro_prefix + 'CompleteAcc')

    # Microop declarations and constructors go ahead of the top-level ones.
    header_output = microDeclTemplate.subst(microiop) + header_output
    decoder_output = microConsTemplate.subst(microiop) + decoder_output
    exec_output = (microExecTemplate.subst(microiop) +
                   microInitTemplate.subst(microiop) +
                   microCompTemplate.subst(microiop))

    return (header_output, decoder_output, decode_block, exec_output)

}};

// Format VleOp: generates its output through VMemBase with the
// 'VleMacroInst' base class and the 'Vle' template prefix. Unless a
// decoder entry overrides ea_code, the effective address is Rs1 plus
// vlenb times the microop index.
def format VleOp(
    memacc_code,
    ea_code={{
        EA = Rs1 + vlenb * microIdx;
    }},
    mem_flags=[],
    inst_flags=[]
) {{
    # is_macroop keeps its default, so VMemBase also emits the microop code
    # from the 'VleMicro*' templates.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VleMacroInst', exec_template_base='Vle')
}};

// Format VseOp: generates its output through VMemBase with the
// 'VseMacroInst' base class and the 'Vse' template prefix. Unless a
// decoder entry overrides ea_code, the effective address is Rs1 plus
// vlenb times the microop index.
def format VseOp(
    memacc_code,
    ea_code={{
        EA = Rs1 + vlenb * microIdx;
    }},
    mem_flags=[],
    inst_flags=[]
) {{
    # is_macroop keeps its default, so VMemBase also emits the microop code
    # from the 'VseMicro*' templates.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VseMacroInst', exec_template_base='Vse')
}};

// Format VlmOp: generates its output through VMemBase with the
// 'VleMacroInst' base class and the 'Vlm' template prefix. The default
// effective address is Rs1 itself.
def format VlmOp(
    memacc_code,
    ea_code={{ EA = Rs1; }},
    mem_flags=[],
    inst_flags=[]
) {{
    # With is_macroop=False VMemBase stops after the top-level instruction:
    # no microop code is generated and exec_output is empty.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VleMacroInst', exec_template_base='Vlm',
        is_macroop=False)
}};

// Format VsmOp: generates its output through VMemBase with the
// 'VseMacroInst' base class and the 'Vsm' template prefix. The default
// effective address is Rs1 itself.
def format VsmOp(
    memacc_code,
    ea_code={{ EA = Rs1; }},
    mem_flags=[],
    inst_flags=[]
) {{
    # With is_macroop=False VMemBase stops after the top-level instruction:
    # no microop code is generated and exec_output is empty.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VseMacroInst', exec_template_base='Vsm',
        is_macroop=False)
}};

// Format VlWholeOp: generates its output through VMemBase with the
// 'VlWholeMacroInst' base class and the 'VlWhole' template prefix. Unless
// a decoder entry overrides ea_code, the effective address is Rs1 plus
// vlenb times the microop index.
def format VlWholeOp(
    memacc_code,
    ea_code={{
        EA = Rs1 + vlenb * microIdx;
    }},
    mem_flags=[],
    inst_flags=[]
) {{
    # is_macroop keeps its default, so VMemBase also emits the microop code
    # from the 'VlWholeMicro*' templates.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VlWholeMacroInst', exec_template_base='VlWhole')
}};

// Format VsWholeOp: generates its output through VMemBase with the
// 'VsWholeMacroInst' base class and the 'VsWhole' template prefix. Unless
// a decoder entry overrides ea_code, the effective address is Rs1 plus
// vlenb times the microop index.
def format VsWholeOp(
    memacc_code,
    ea_code={{
        EA = Rs1 + vlenb * microIdx;
    }},
    mem_flags=[],
    inst_flags=[]
) {{
    # is_macroop keeps its default, so VMemBase also emits the microop code
    # from the 'VsWholeMicro*' templates.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VsWholeMacroInst', exec_template_base='VsWhole')
}};

// Format VlStrideOp: generates its output through VMemBase with the
// 'VlElementMacroInst' base class. The top-level constructor comes from the
// 'VlStride' template prefix, while the microop templates come from the
// 'VlElement' prefix. The default effective address is
// Rs1 + Rs2 * ei + offset.
def format VlStrideOp(
    memacc_code,
    ea_code={{
        EA = Rs1 + Rs2 * ei + offset;
    }},
    mem_flags=[],
    inst_flags=[]
) {{
    # microop_template_base overrides the prefix used for the microop
    # templates and the microop base class.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VlElementMacroInst', exec_template_base='VlStride',
        microop_template_base='VlElement')
}};

// Format VsStrideOp: generates its output through VMemBase with the
// 'VsElementMacroInst' base class. The top-level constructor comes from the
// 'VsStride' template prefix, while the microop templates come from the
// 'VsElement' prefix. The default effective address is
// Rs1 + Rs2 * ei + offset.
def format VsStrideOp(
    memacc_code,
    ea_code={{
        EA = Rs1 + Rs2 * ei + offset;
    }},
    mem_flags=[],
    inst_flags=[]
) {{
    # microop_template_base overrides the prefix used for the microop
    # templates and the microop base class.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VsElementMacroInst', exec_template_base='VsStride',
        microop_template_base='VsElement')
}};

// Format VlIndexOp: generates its output through VMemBase with the
// 'VlIndexMacroInst' base class and the 'VlIndex' template prefix. There is
// no default ea_code, so each decoder entry has to supply one.
def format VlIndexOp(
    memacc_code,
    ea_code,
    mem_flags=[],
    inst_flags=[]
) {{
    # The declaration and decode block use the templated variants instead of
    # VMemBase's defaults.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VlIndexMacroInst', exec_template_base='VlIndex',
        declare_template_base=VMemTemplateMacroDeclare,
        decode_template=VMemSplitTemplateDecodeBlock)
}};

// Format VsIndexOp: generates its output through VMemBase with the
// 'VsIndexMacroInst' base class and the 'VsIndex' template prefix. There is
// no default ea_code, so each decoder entry has to supply one.
def format VsIndexOp(
    memacc_code,
    ea_code,
    mem_flags=[],
    inst_flags=[]
) {{
    # The declaration and decode block use the templated variants instead of
    # VMemBase's defaults.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VsIndexMacroInst', exec_template_base='VsIndex',
        declare_template_base=VMemTemplateMacroDeclare,
        decode_template=VMemSplitTemplateDecodeBlock)
}};

// Format VlSegOp: generates its output through VMemBase with the
// 'VlSegMacroInst' base class and the 'VlSeg' template prefix. The default
// effective address is Rs1 plus mem_size times
// (microIdx + field * numMicroops).
def format VlSegOp(
    memacc_code,
    ea_code={{
        EA = Rs1 + mem_size * (microIdx + (field * numMicroops));
    }},
    mem_flags=[],
    inst_flags=[]
) {{
    # is_macroop keeps its default, so VMemBase also emits the microop code
    # from the 'VlSegMicro*' templates.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VlSegMacroInst', exec_template_base='VlSeg')
}};

// Format VsSegOp: generates its output through VMemBase with the
// 'VsSegMacroInst' base class and the 'VsSeg' template prefix. The default
// effective address is Rs1 plus mem_size times
// (microIdx + field * numMicroops).
def format VsSegOp(
    memacc_code,
    ea_code={{
        EA = Rs1 + mem_size * (microIdx + (field * numMicroops));
    }},
    mem_flags=[],
    inst_flags=[]
) {{
    # is_macroop keeps its default, so VMemBase also emits the microop code
    # from the 'VsSegMicro*' templates.
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VsSegMacroInst', exec_template_base='VsSeg')
}};
